package t2go

import (
	"fmt"
	"gitee.com/phpdi/cycmd/internal/word"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"sort"
	"strings"
)

// typeForMysqlToGo translates a MySQL column type name into the Go type that
// represents it. Entries are grouped by type family; lookups that miss yield
// the empty string.
var typeForMysqlToGo = map[string]string{
	// Signed integers.
	"tinyint":   "int64",
	"smallint":  "int64",
	"mediumint": "int64",
	"int":       "int",
	"integer":   "int64",
	"bigint":    "int", // NOTE(review): narrower types map to int64 but bigint maps to int — confirm this is intended.

	// Unsigned integers.
	"tinyint unsigned":   "int64",
	"smallint unsigned":  "int64",
	"mediumint unsigned": "int64",
	"int unsigned":       "int64",
	"integer unsigned":   "int64",
	"bigint unsigned":    "int64",

	// Bit fields and booleans.
	"bit":  "int64",
	"bool": "bool",

	// Character data, including enumerations and sets.
	"char":       "string",
	"varchar":    "string",
	"tinytext":   "string",
	"text":       "string",
	"mediumtext": "string",
	"longtext":   "string",
	"enum":       "string",
	"set":        "string",

	// Binary data.
	"binary":     "string",
	"varbinary":  "string",
	"tinyblob":   "string",
	"blob":       "string",
	"mediumblob": "string",
	"longblob":   "string",

	// Temporal types; string would be the alternative representation.
	"date":      "time.Time",
	"time":      "time.Time",
	"datetime":  "time.Time",
	"timestamp": "time.Time",

	// Floating-point and fixed-point numbers.
	"float":   "float64",
	"double":  "float64",
	"decimal": "float64",
}

// mysqlTableInfo holds the connection and the metadata gathered while
// describing a single MySQL table.
type mysqlTableInfo struct {
	db             *gorm.DB        // connection opened by init from dsn
	dsn            string          // data source name given to NewMysqlTableInfo
	info           Info            // result assembled by GetTableInfo
	tableName      string          // table currently being described
	mysqlColIndexs []mysqlColIndex // rows scanned from SHOW INDEX
	cols           []mysqlCol      // rows scanned from SHOW FULL COLUMNS
}

// mysqlCol receives one row of SHOW FULL COLUMNS.
type mysqlCol struct {
	Field   string // column name
	Type    string // column type
	Key     string // index kind; "PRI" marks a primary-key column
	Comment string // column comment
	Null    string // nullability: "NO" or "YES"
	Default string // default value
	Extra   string // extra attributes such as auto_increment
}

// mysqlColIndex receives one row of SHOW INDEX.
type mysqlColIndex struct {
	Key_name    string // index name
	Non_unique  int    // 0 = unique index, 1 = non-unique index
	Column_name string // name of the indexed column
}

// MysqlTable carries a table-level comment.
// NOTE(review): not referenced in the visible part of this file.
type MysqlTable struct {
	Comment string
}

// NewMysqlTableInfo returns a TableInfo backed by the MySQL server that dsn
// points at. No connection is opened here; dsn is only stored.
func NewMysqlTableInfo(dsn string) TableInfo {
	mti := new(mysqlTableInfo)
	mti.dsn = dsn
	return mti
}

// GetTableInfo connects to the database, reads the table status, columns and
// indexes of tableName and returns them as an Info.
//
// State left over from a previous call is discarded first, so calling the
// method repeatedly on the same value never mixes columns of different calls.
// The connection opened for the call is closed before returning. On error the
// returned Info is the zero value.
func (mti *mysqlTableInfo) GetTableInfo(tableName string) (info Info, err error) {
	mti.tableName = tableName

	// Columns are appended to mti.info below; start from a clean slate so a
	// second call does not return the first call's columns as well.
	mti.info = Info{}
	mti.cols = nil
	mti.mysqlColIndexs = nil

	if err = mti.init(); err != nil {
		return Info{}, err
	}

	defer func() {
		sqlDB, dbErr := mti.db.DB()
		if dbErr != nil {
			// No underlying handle is available, so there is nothing to close.
			return
		}
		// Best-effort close: the metadata has already been read (or the call
		// has already failed), so a close error is deliberately ignored.
		_ = sqlDB.Close()
	}()

	// Table-level information (name and comment).
	if err = mti.setTableInfo(); err != nil {
		return Info{}, err
	}

	// Column definitions.
	if err = mti.setTableColumn(); err != nil {
		return Info{}, err
	}

	// Index definitions.
	if err = mti.setMysqlColIndex(); err != nil {
		return Info{}, err
	}

	for _, v := range mti.cols {
		mti.info.Columns = append(mti.info.Columns, Column{
			Field:      v.Field,
			GoType:     v.goType(),
			SourceType: v.sourceType(),
			PK:         v.pk(),
			Comment:    v.Comment,
			GormTag:    v.gormTag(mti.mysqlColIndexs),
		})
	}

	return mti.info, nil
}

// init opens a gorm connection for mti.dsn using the MySQL driver and stores
// it in mti.db. The error from gorm.Open is returned unchanged.
func (mti *mysqlTableInfo) init() (err error) {
	dialector := mysql.Open(mti.dsn)
	mti.db, err = gorm.Open(dialector, &gorm.Config{})
	return err
}

// setTableInfo reads the table comment via SHOW TABLE STATUS and fills in
// mti.info.Name and mti.info.Comment. The comment falls back to the table
// name when the table has none (or no status row was found).
//
// The database name taken from the DSN is quoted as an identifier, with
// embedded backticks doubled. When no database name could be extracted the
// FROM clause is omitted, so the statement runs against the connection's
// default database instead of failing on an empty identifier.
func (mti *mysqlTableInfo) setTableInfo() (err error) {
	type table struct {
		Comment string
	}

	query := "SHOW TABLE STATUS WHERE Name=?"
	if dbName := mti.getDbname(); dbName != "" {
		quoted := "`" + strings.Replace(dbName, "`", "``", -1) + "`"
		query = fmt.Sprintf("SHOW TABLE STATUS FROM %s WHERE Name=?", quoted)
	}

	res := table{}
	// Query the table status row; the table name is passed as a bound value.
	if err = mti.db.Raw(query, mti.tableName).Scan(&res).Error; err != nil {
		return err
	}

	mti.info.Name = mti.tableName
	mti.info.Comment = mti.tableName
	if res.Comment != "" {
		mti.info.Comment = res.Comment
	}

	return nil
}

// quoteMysqlTableName renders name as a backtick-quoted MySQL identifier so
// that reserved words and unusual characters are accepted in statements that
// cannot take the table name as a bound value.
//
// A dot separates an optional database qualifier from the table name; each
// part is quoted on its own. Backticks already surrounding a part are removed
// first and any backtick inside a part is doubled.
func quoteMysqlTableName(name string) string {
	parts := strings.Split(name, ".")
	for i, part := range parts {
		part = strings.Trim(part, "`")
		parts[i] = "`" + strings.Replace(part, "`", "``", -1) + "`"
	}
	return strings.Join(parts, ".")
}

// setTableColumn loads the column definitions of the current table into
// mti.cols using SHOW FULL COLUMNS.
func (mti *mysqlTableInfo) setTableColumn() (err error) {
	query := "SHOW FULL COLUMNS FROM " + quoteMysqlTableName(mti.tableName)
	return mti.db.Raw(query).Scan(&mti.cols).Error
}

// setMysqlColIndex loads the index definitions of the current table into
// mti.mysqlColIndexs using SHOW INDEX, then derives mti.info.UniqueKey from
// them. UniqueKey is left untouched when the query fails.
func (mti *mysqlTableInfo) setMysqlColIndex() (err error) {
	query := "SHOW INDEX FROM " + quoteMysqlTableName(mti.tableName)
	if err = mti.db.Raw(query).Scan(&mti.mysqlColIndexs).Error; err != nil {
		return err
	}

	mti.info.UniqueKey = mti.uniqueKey()
	return nil
}

// getDbname extracts the database name from mti.dsn, i.e. the text between
// the last "/" and the following "?" (or the end of the string).
//
// The last slash is used because earlier parts of a DSN may contain slashes
// themselves, e.g. a unix socket path or a password. It returns "" when the
// DSN has no slash or names no database.
func (mti *mysqlTableInfo) getDbname() string {
	slash := strings.LastIndex(mti.dsn, "/")
	if slash < 0 {
		return ""
	}

	name := mti.dsn[slash+1:]
	if q := strings.Index(name, "?"); q >= 0 {
		name = name[:q]
	}

	return name
}

// size returns the gorm "size:" tag fragment for the column.
//
// A varchar column yields its declared length; a column whose exact type
// string maps to int or int64 in typeForMysqlToGo yields size 64; every other
// column yields "".
func (mc mysqlCol) size() string {
	const varcharPrefix = "varchar"

	if strings.HasPrefix(mc.Type, varcharPrefix) {
		// What remains after the prefix is "(N)"; strip the parentheses.
		length := strings.TrimPrefix(mc.Type, varcharPrefix)
		length = strings.Trim(strings.Trim(length, "("), ")")
		return fmt.Sprintf("size:%s;", length)
	}

	if goType := typeForMysqlToGo[mc.Type]; goType == "int" || goType == "int64" {
		return "size:64;"
	}

	return ""
}

// notnull returns the gorm "not null;" tag fragment when the column's Null
// value is "NO", and "" otherwise.
func (mc mysqlCol) notnull() string {
	if mc.Null != "NO" {
		return ""
	}
	return "not null;"
}

// defaultS returns the gorm "default:" tag fragment for the column.
//
// String-typed columns always get a quoted default (the empty string when no
// default is recorded). Other columns get their default unquoted, or no
// fragment at all when no default is recorded.
func (mc mysqlCol) defaultS() string {
	isString := mc.goType() == "string"

	switch {
	case mc.Default == "" && isString:
		return "default:'';"
	case mc.Default == "":
		return ""
	case isString:
		return fmt.Sprintf("default:'%s';", mc.Default)
	default:
		return fmt.Sprintf("default:%s;", mc.Default)
	}
}

// comment returns the gorm "comment:" tag fragment for the column, or "" when
// the column has no comment.
func (mc mysqlCol) comment() string {
	if mc.Comment == "" {
		return ""
	}
	return fmt.Sprintf("comment:%s;", mc.Comment)
}

// goType returns the Go type used for the column, or "" when the MySQL type
// is not present in typeForMysqlToGo.
//
// A column named deleted_at always maps to gorm.DeletedAt. Otherwise the
// lookup tries, in order: the full type string, the text before the first
// "(", and finally the first word of that text. Cutting at the first "("
// rather than requiring exactly one keeps enum/set definitions whose values
// contain parentheses resolvable; the first-word fallback covers types that
// carry trailing modifiers without a length, such as "double unsigned".
func (mc mysqlCol) goType() string {
	if mc.Field == "deleted_at" {
		return "gorm.DeletedAt"
	}

	if goType, ok := typeForMysqlToGo[mc.Type]; ok {
		return goType
	}

	base := mc.Type
	if i := strings.Index(base, "("); i >= 0 {
		base = base[:i]
	}
	if goType, ok := typeForMysqlToGo[base]; ok {
		return goType
	}

	if i := strings.Index(base, " "); i >= 0 {
		base = base[:i]
	}

	return typeForMysqlToGo[base]
}

// sourceType returns the MySQL type name with everything from the first "("
// onwards removed, e.g. "varchar(32)" becomes "varchar". A type without
// parentheses is returned unchanged.
//
// Cutting at the first "(" (instead of requiring exactly one) also handles
// enum/set definitions whose values contain parentheses.
func (mc mysqlCol) sourceType() string {
	if i := strings.Index(mc.Type, "("); i >= 0 {
		return mc.Type[:i]
	}

	return mc.Type
}

// autoIncrement returns the gorm "autoIncrement;" tag fragment when the
// column's Extra attribute mentions auto_increment, and "" otherwise.
func (mc mysqlCol) autoIncrement() string {
	if !strings.Contains(mc.Extra, "auto_increment") {
		return ""
	}
	return "autoIncrement;"
}
// pk reports whether the column's Key value is "PRI".
func (mc mysqlCol) pk() bool {
	switch mc.Key {
	case "PRI":
		return true
	default:
		return false
	}
}

// pkStr returns the gorm "primaryKey;" tag fragment for a primary-key column
// and "" for any other column.
func (mc mysqlCol) pkStr() string {
	if !mc.pk() {
		return ""
	}
	return "primaryKey;"
}

// index concatenates the gorm index tag fragments of every entry in mis whose
// Column_name equals this column's Field, in the order they appear in mis.
func (mc mysqlCol) index(mis []mysqlColIndex) (res string) {
	for i := range mis {
		if mis[i].Column_name == mc.Field {
			res += mis[i].index()
		}
	}
	return res
}

// gormTag builds the complete gorm struct tag for the column.
//
// The created_at, updated_at and deleted_at columns receive fixed tags. Every
// other column gets the concatenation of its primary-key, auto-increment,
// index, size, not-null, default and comment fragments, in that order.
func (mc mysqlCol) gormTag(mis []mysqlColIndex) string {
	if mc.Field == "created_at" {
		return `gorm:"type:TIMESTAMP;not null;default:CURRENT_TIMESTAMP;comment:创建时间"`
	}
	if mc.Field == "updated_at" {
		return `gorm:"type:TIMESTAMP;not null;default:CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP;comment:最后更新时间"`
	}
	if mc.Field == "deleted_at" {
		return `gorm:"index;comment:删除时间"`
	}

	fragments := []string{
		mc.pkStr(),
		mc.autoIncrement(),
		mc.index(mis),
		mc.size(),
		mc.notnull(),
		mc.defaultS(),
		mc.comment(),
	}

	var tag strings.Builder
	tag.WriteString(`gorm:"`)
	for _, fragment := range fragments {
		tag.WriteString(fragment)
	}
	tag.WriteString(`"`)
	return tag.String()
}

// index returns the gorm tag fragment describing this index entry: nothing
// for the PRIMARY key, "uniqueIndex:<name>;" when Non_unique is 0, and
// "index:<name>;" otherwise.
func (mi mysqlColIndex) index() string {
	switch {
	case mi.Key_name == "PRIMARY":
		return ""
	case mi.Non_unique == 0:
		return fmt.Sprintf("uniqueIndex:%s;", mi.Key_name)
	default:
		return fmt.Sprintf("index:%s;", mi.Key_name)
	}
}

// uniqueKey groups the columns of every unique, non-primary index by index
// name and returns one slice of upper-camel-case column names per index.
// Groups are ordered by index name; within a group the columns keep the order
// in which they appear in mti.mysqlColIndexs. The result is nil when there is
// no such index.
func (mti *mysqlTableInfo) uniqueKey() (res [][]string) {
	columnsByIndex := map[string][]string{}
	var indexNames []string

	for _, idx := range mti.mysqlColIndexs {
		if idx.Non_unique != 0 || idx.Key_name == "PRIMARY" {
			continue
		}

		if _, seen := columnsByIndex[idx.Key_name]; !seen {
			indexNames = append(indexNames, idx.Key_name)
		}
		field := word.UnderscoreToUpperCamelCase(idx.Column_name)
		columnsByIndex[idx.Key_name] = append(columnsByIndex[idx.Key_name], field)
	}

	// Map iteration order is random; sort the names for a stable result.
	sort.Strings(indexNames)
	for _, name := range indexNames {
		res = append(res, columnsByIndex[name])
	}

	return res
}
